import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt

def fun_values(data):
    """Return tensor `data` as a NumPy array, detached from autograd and on the CPU."""
    detached = data.detach()
    return detached.cpu().numpy()

# Wall-clock reference; `begin` is read later when reporting training time.
import time
begin = time.time()

# Seed every RNG in use so runs are reproducible.
LUCKY_NUM = 666
torch.manual_seed(LUCKY_NUM)
torch.cuda.manual_seed(LUCKY_NUM)  # silently ignored when CUDA is unavailable
np.random.seed(LUCKY_NUM)

# Pick the compute device: first GPU when available, otherwise the CPU.
print(torch.cuda.is_available())
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
if torch.cuda.is_available():
    # set_device raises on CPU-only machines, so only select a GPU if one exists.
    torch.cuda.set_device(0)

#%% Control parameters
# num_init is passed to fun_split as `case`:
#   0 -> draw a new random train/test split and save it,
#   1 -> reload the previously saved split.
print('0 for initial operation\n 1 for repetitive operation\n please enter:')
num_init = 1
print('num init={}'.format(num_init))

# %% Import Data
path_data = './dataset/'
# np.load returns an NpzFile that keeps the archive open; the context manager
# closes it once every array below has been read into memory.
with np.load(path_data + 'data02_dataset_CNNinv.npz') as data:
    data_x = data['data_x']
    data_y = data['data_y']
    data_fc_ori = data['data_fc']
    data_fix = data['data_fix']
    data_obs_loc_ori = data['data_obs_loc']
    data_scale = data['data_scale']
    data_para_range = data['para_range']

# Keep the first 1096 entries of axis 1 (the date axis consumed by swb).
# NOTE: data_fc is a view, so the in-place rescale of channel 1 also changes
# data_fc_ori, which is not read again below.
data_fc = data_fc_ori[:, :1096, :]
data_fc[:, :, 1] = data_fc[:, :, 1] * 0.5
data_obs_loc = data_obs_loc_ori[:, :1096]


# %% Normalization
def fun_norm(data, case=0):
    """Normalize `data` feature-wise across the sample axis (axis 0), ignoring NaNs.

    Args:
        data: array of shape (n_samples, ...); statistics are taken over axis 0.
        case: 0 -> z-score, (data - mean) / std;
              1 -> min-max, (data - min) / (max - min).

    Returns:
        (data_norm, (offset, scale)) with data_norm = (data - offset) / scale.
        For case 0 the pair is (mean, std); for case 1 it is (min, max - min).

    Raises:
        ValueError: if `case` is neither 0 nor 1.
    """
    if case == 0:
        offset = np.nanmean(data, axis=0)
        scale = np.nanstd(data, axis=0)
    elif case == 1:
        data_min = np.nanmin(data, axis=0)
        offset = data_min
        scale = np.nanmax(data, axis=0) - data_min
    else:
        raise ValueError('case must be 0 (z-score) or 1 (min-max), got {!r}'.format(case))
    # Broadcasting applies the per-feature statistics to every sample for any
    # rank of `data`; np.tile(..., (n_data, 1, 1, 1)) was only correct for 4-D input.
    data_norm = (data - offset) / scale
    return data_norm, (offset, scale)

# z-score the CNN inputs per feature position; the targets (data_y) stay unnormalized.
x_data_norm, (x_mean, x_std) = fun_norm(data_x, case=0)
# Entries that are NaN after normalization (e.g. missing inputs) are set to 0.
nan_positions = np.isnan(x_data_norm)
x_data_norm[nan_positions] = 0
   
# %% Randomly split training and validation sets
def fun_split(train_pc, x_data_norm, y_data_norm, fc_data,fix_para,obs_loc,case=1, path=None):
    """Split every per-sample array into train and test sets with one shared permutation.

    Args:
        train_pc: fraction of samples used for training; int(n * train_pc) go to train.
        x_data_norm, y_data_norm, fc_data, fix_para, obs_loc: arrays indexed by
            sample on axis 0.
        case: 0 draws a new random split (seed 666) and saves the index order
            to path + 'random_index.txt'; 1 reloads that file to reproduce it.
        path: directory prefix for random_index.txt (joined by string
            concatenation, so include the trailing separator). Defaults to the
            module-level path_data.

    Returns:
        Ten arrays: the (train, test) pairs for x, y, fc_data, fix_para and
        obs_loc, in that order.

    Raises:
        ValueError: if `case` is neither 0 nor 1.
    """
    if path is None:
        path = path_data
    index_file = path + 'random_index.txt'

    n_data = x_data_norm.shape[0]
    n_train = int(n_data * train_pc)
    n_test = int(n_data - n_train)
    print('{} for training and {} for test'.format(n_train, n_test))

    if case == 0:
        data_index = np.arange(n_data)
        # A private RandomState(666) yields the same draw as np.random.seed(666)
        # followed by np.random.choice, without resetting the global NumPy RNG.
        rng = np.random.RandomState(666)
        train_index = rng.choice(data_index, n_train, replace=False)
        # setdiff1d returns the remaining indices sorted, so the saved order no
        # longer depends on Python set iteration order.
        test_index = np.setdiff1d(data_index, train_index)
        random_index = np.concatenate((train_index, test_index))
        np.savetxt(index_file, random_index)
    elif case == 1:
        random_index = np.loadtxt(index_file, ndmin=1)
    else:
        raise ValueError('case must be 0 (new split) or 1 (reload split), got {!r}'.format(case))

    random_index = random_index.astype(int)
    train_rows = random_index[:n_train]
    test_rows = random_index[n_train:]

    def fun_split_inter(data):
        """Return the (train, test) rows of `data` following random_index."""
        return data[train_rows], data[test_rows]

    x_train, x_test = fun_split_inter(x_data_norm)
    y_train, y_test = fun_split_inter(y_data_norm)
    fc_data_train, fc_data_test = fun_split_inter(fc_data)
    fix_para_train, fix_para_test = fun_split_inter(fix_para)
    obs_loc_train, obs_loc_test = fun_split_inter(obs_loc)

    return x_train, x_test, y_train, y_test,fc_data_train,fc_data_test,fix_para_train,fix_para_test,obs_loc_train,obs_loc_test

# 80 / 20 train-test split; num_init selects drawing a new split or reloading the saved one.
(x_train, x_test,
 y_train, y_test,
 fc_data_train, fc_data_test,
 fix_para_train, fix_para_test,
 obs_loc_train, obs_loc_test) = fun_split(
    0.8, x_data_norm, data_y, data_fc, data_fix,
    obs_loc=data_obs_loc, case=num_init,
)

#%% Production of datasets
class ForDataset(Dataset):
    """In-memory Dataset over pre-split NumPy arrays.

    Every per-sample array is indexed on axis 0. The arrays are converted once,
    at construction, to float32 CPU tensors (obs_loc to int32). `scale_data`
    is shared and returned whole with every item.
    """

    def __init__(self, x, y,fc_data,fix_para,obs_loc,scale_data):
        def as_float(arr):
            return torch.from_numpy(arr).type(torch.FloatTensor)

        self.length = x.shape[0]
        self.x_data = as_float(x)
        self.y_data = as_float(y)  # float targets for MSE / L1-style losses
        self.fc_data = as_float(fc_data)
        self.fix_para = as_float(fix_para)
        self.obs_loc = torch.from_numpy(obs_loc).type(torch.IntTensor)
        self.scale_data = as_float(scale_data)

    def __getitem__(self, index):
        """Return (x, y, fc_data, fix_para, obs_loc, scale_data) for sample `index`."""
        per_sample = (self.x_data, self.y_data, self.fc_data, self.fix_para, self.obs_loc)
        return tuple(t[index] for t in per_sample) + (self.scale_data,)

    def __len__(self):
        return self.length

# Wrap each split as a Dataset; both share the same data_scale array.
dataset_train = ForDataset(
    x=x_train, y=y_train, fc_data=fc_data_train,
    fix_para=fix_para_train, obs_loc=obs_loc_train, scale_data=data_scale,
)
dataset_test = ForDataset(
    x=x_test, y=y_test, fc_data=fc_data_test,
    fix_para=fix_para_test, obs_loc=obs_loc_test, scale_data=data_scale,
)

    
# %% model design
class swb(torch.nn.Module):
    '''
    Differentiable soil-water-balance (bucket) model used as the physics layer
    of CNNModel_Inv.

    The module has no learnable weights: every parameter arrives through the
    `para` argument of forward(), and intermediate quantities are cached as
    attributes on self during each call.
    '''
    
    # def __init__(self, **kwargs):
    #     None
        
    def forward(self,fc_data,para):
        '''
        Step the root-zone depletion (dr) / soil moisture (theta) state through every date.

        Args:
            fc_data: forcing tensor indexed [grid, date, channel]. Channel 0 is
                used as effective precipitation (ep). Channel 1, multiplied by
                ks, gives evapotranspiration (eta); see dr_without_irrig.
            para: parameter tensor indexed [grid, column]. Columns 0-7 are read
                as theta_s, theta_fc, theta_wp, zr, zr_factor, p, theta_init
                and Ksat; any further columns are not read here.

        Returns:
            Tensor [grid, date, 6] with channels 0 theta, 1 ks, 2 ep, 3 eta,
            4 ro, 5 dp. ep/eta/ro/dp are summed over the sub-steps of each date.
        '''
        # default the np.exp has been used in CNN
        self.timeseries = fc_data
        # Each parameter is kept as a [grid, 1] column so it broadcasts per grid cell.
        self.theta_s=para[:,0,None]
        self.theta_fc=para[:,1,None]
        self.theta_wp=para[:,2,None]
        self.zr=para[:,3,None]
        self.zr_factor=para[:,4,None]
        self.p=para[:,5,None]
        self.theta_init=para[:,6,None]
        self.Ksat=para[:,7,None]

        self.ngrid=self.timeseries.shape[0]
        self.ndate=self.timeseries.shape[1]
        self.delta_t=1
        
        # Output buffer [grid, date, 6], filled in place below.
        # NOTE(review): only the CUDA branch passes requires_grad=True; in both
        # cases the in-place writes of grad-carrying values below connect the
        # buffer to the autograd graph.
        if fc_data.device.type=='cuda':
            self.results_save=torch.zeros([self.ngrid,self.ndate,6],requires_grad=True).cuda()
        else:
            self.results_save=torch.zeros([self.ngrid,self.ndate,6]).cpu()
        # Total available water in the root zone and the readily available part of it.
        self.taw = (self.theta_fc - self.theta_wp) * self.zr * self.zr_factor
        self.raw = self.p * self.taw

        # Time loop: start from theta_init and advance the depletion state.
        theta_prev=self.theta_init
        dr_prev = self.dr_from_theta(theta_prev)
        # Sub-steps per date; with delta_t == 1 this is a single step per date.
        n_delta=int(1/self.delta_t)
        shape1=[self.ngrid,n_delta]# [grid, sub-steps] buffers for the per-step fluxes
        
        for date in range(self.ndate):
            if fc_data.device.type=='cuda':
                ep=torch.zeros(shape1,requires_grad=True).cuda()
                eta=torch.zeros(shape1,requires_grad=True).cuda()
                ro=torch.zeros(shape1,requires_grad=True).cuda()
                dp=torch.zeros(shape1,requires_grad=True).cuda()
            else:
                ep=torch.zeros(shape1).cpu()
                eta=torch.zeros(shape1).cpu()
                ro=torch.zeros(shape1).cpu()
                dp=torch.zeros(shape1).cpu()
            for i_iter in range(n_delta):
                row = self.timeseries[:,date,:]
                # ks = self.ks(dr_prev)
                # ks is evaluated from the depletion at the start of the step.
                ks =self.ks(dr_prev)
                # The step's fluxes are written into column i_iter of the per-date buffers.
                dr_without_irrig, ep[:,i_iter], eta[:,i_iter],ro[:,i_iter],dp[:,i_iter]  = self.dr_without_irrig(dr_prev, theta_prev, ks, row)
                # Irrigation is not modelled (pjli 20220406): net irrigation is fixed at 0.
                assumed_net_irrigation=0
                dr = self.dr(dr_without_irrig, assumed_net_irrigation)
                theta = self.theta_from_dr(dr)
                theta_prev = theta
                dr_prev = dr
            # print('ks',ks,date,(theta-self.theta_wp)/self.taw*1000*0.9)
            # print(self.theta_fc-dr_prev/0.9/1000,theta_prev
            # Store end-of-date state and the fluxes summed over the sub-steps.
            self.results_save[:,date, 0] = theta.squeeze()
            self.results_save[:,date, 1] = ks.squeeze()
            # self.results_save[:,date, 2] = recommended_net_irrigation
            # self.results_save[:,date, 3] = assumed_net_irrigation
            self.results_save[:,date, 2] = ep.sum(1)
            self.results_save[:,date, 3] = eta.sum(1)
            self.results_save[:,date, 4] = ro.sum(1)
            self.results_save[:,date, 5] = dp.sum(1)
        return self.results_save


    def dr_from_theta(self, theta):
        """Convert soil moisture to root-zone depletion: (theta_fc - theta) * zr * zr_factor."""
        return (self.theta_fc - theta) * self.zr * self.zr_factor

    def theta_from_dr(self, dr):
        """Inverse of dr_from_theta: theta_fc - dr / (zr * zr_factor)."""
        return self.theta_fc - dr / (self.zr * self.zr_factor)

    def ks(self, dr):
        """Water-stress coefficient for depletion `dr`.

        Where |1 - p| > 1e-6 this is (taw - dr) / ((1 - p) * taw); elsewhere it
        stays 1, which avoids dividing by (1 - p) ~ 0. The result is capped at 1.
        NOTE(review): there is no lower clamp, so ks < 0 is possible if dr > taw
        (e.g. from theta_init before dr() first caps the depletion at taw).
        """
        result=torch.ones_like(self.p)
        result[torch.absolute(1-self.p) > 1e-6]=(self.taw[torch.absolute(1-self.p) > 1e-6] - dr[torch.absolute(1-self.p) > 1e-6]) / ((1 - self.p[torch.absolute(1-self.p) > 1e-6]) * self.taw[torch.absolute(1-self.p) > 1e-6])
        return torch.minimum(result, torch.tensor(1))

    def ro(self, effective_precipitation, theta_prev):
        """Runoff: max(P + (theta_prev - theta_s) * zr * zr_factor, 0), i.e. the input above saturation storage."""
        result = (
            effective_precipitation
            + (theta_prev - self.theta_s) * self.zr * self.zr_factor
        )
        return torch.maximum(result, torch.tensor(0))

    def dp(self, theta_prev, peff):
        """Deep percolation over one step.

        Excess water above field capacity is capped at (theta_s - theta_fc) in
        storage units. Where the excess is above 1e-7, it decays as
        excess * exp(-Ksat / zr / excess_water_limit * delta_t), and dp is the
        drop from theta_mm to that level. The result is clamped to [0, excess_water].
        """
        delta_t=self.delta_t
        theta_mm = theta_prev * self.zr * self.zr_factor+peff
        theta_fc_mm = self.theta_fc * self.zr * self.zr_factor
        theta_s_mm = self.theta_s * self.zr * self.zr_factor
        # theta = min(theta_mm, self.theta_s_mm)
        # NOTE(review): theta_mm already includes peff, so peff is counted twice
        # in excess_water here — confirm this is intended.
        excess_water = theta_mm - theta_fc_mm + peff
        excess_water_limit=theta_s_mm-theta_fc_mm
        # need do under different conditions pjli 20220405
        excess_water[excess_water>excess_water_limit]=excess_water_limit[excess_water>excess_water_limit]
        dp_water=torch.zeros_like(excess_water)
        dp_water[excess_water <1.0E-7]=0.0
        th_now_mm=theta_fc_mm+excess_water*torch.exp(-self.Ksat/self.zr/excess_water_limit*delta_t)
        dp_water[excess_water >1.0E-7]=theta_mm[excess_water >1.0E-7]-th_now_mm[excess_water >1.0E-7]
        return torch.maximum(torch.minimum(dp_water,excess_water),torch.tensor(0)) # since excess-water may be negative

    def dr_without_irrig(self, dr_prev, theta_prev, ks, row):
        """Advance depletion one step without irrigation: dr_prev - (ep - ro) + eta + dp.

        `row` is the [grid, channel] forcing slice for the current date.
        Returns (dr, ep, eta, ro, dp), each squeezed.
        """
        # "row" is a single row from self.timeseries
        ep = row[:,0,None]*self.delta_t
        eta= row[:,1,None]*self.delta_t * ks
        ro = self.ro(ep, theta_prev)
        dp = self.dp(theta_prev, ep)
        dr_noirr=dr_prev-(ep-ro)+eta+dp
        return dr_noirr.squeeze(),ep.squeeze(), eta.squeeze(),ro.squeeze(),dp.squeeze()

    def dr(self, dr_without_irrig, assumed_net_irrigation):
        """Subtract net irrigation and cap the depletion at taw.

        dr_without_irrig arrives squeezed from dr_without_irrig(); it is
        reshaped so the result broadcasts against taw ([grid, 1]). With a single
        grid cell the squeezed value is 0-dim, hence the separate branch.
        """
        if self.ngrid==1:
            dr_without_irrig=dr_without_irrig[None]
            result = dr_without_irrig - assumed_net_irrigation
        else:
            result = dr_without_irrig[:,None] - assumed_net_irrigation
        result = torch.minimum(result, self.taw)
        return result
        
# Flatten every dimension except the batch dimension into one.
class Flatten(nn.Module):
    """Reshape a (batch, ...) tensor to (batch, -1), keeping the batch axis."""

    def forward(self, input):
        batch_size = input.size(0)
        return input.view(batch_size, -1)
class ShallowCNN_Ks(nn.Module):
    '''
    Small CNN regressor producing one value per sample (used by CNNModel_Inv
    as the Ksat output).

    Architecture:
        Conv(2->10, k=5) -> AvgPool(2) -> LeakyReLU
        -> Conv(10->20, k=2) -> AvgPool(2) -> LeakyReLU
        -> flatten -> Linear(20->10) -> LeakyReLU -> Linear(10->1)
    linear1 expects exactly 20 flattened features, so the spatial size must be
    1x1 after the second pooling (e.g. a 10x10 input).
    '''
    def __init__(self):
        # Layers are created in the original order so seeded initialization is unchanged.
        super(ShallowCNN_Ks, self).__init__()
        self.cnn1 = nn.Conv2d(2, 10, kernel_size=5)
        self.cnn2 = nn.Conv2d(10, 20, kernel_size=2)
        self.linear1 = nn.Linear(20, 10)
        self.linear2 = nn.Linear(10, 1)
        self.pool1 = nn.AvgPool2d(2)
        # Defined but not applied in forward (dropout is currently disabled).
        self.drop = nn.Dropout2d()
        self.flat1 = nn.Flatten()

    def forward(self, x):
        '''
        Args:
            x: tensor of shape (batch, 2, H, W).

        Returns:
            Tensor of shape (batch, 1).
        '''
        # torch.nn.LeakyReLu does not exist (and nn.LeakyReLU is a module class,
        # not a function), so the functional form is used for the activation.
        x = nn.functional.leaky_relu(self.pool1(self.cnn1(x)))
        x = nn.functional.leaky_relu(self.pool1(self.cnn2(x)))
        x = self.flat1(x)
        x = nn.functional.leaky_relu(self.linear1(x))
        return self.linear2(x)
     
class ShallowCNN_ab(nn.Module):
    '''
    Small CNN regressor producing two values per sample (used by CNNModel_Inv
    as the a and b outputs).

    Architecture:
        Conv(2->10, k=5) -> AvgPool(2) -> ReLU
        -> Conv(10->20, k=2) -> AvgPool(2) -> ReLU
        -> flatten -> Linear(20->10) -> ReLU -> Linear(10->2)
    linear1 expects exactly 20 flattened features, so the spatial size must be
    1x1 after the second pooling (e.g. a 10x10 input).
    '''
    def __init__(self):
        # Layers are created in the original order so seeded initialization is unchanged.
        super(ShallowCNN_ab, self).__init__()
        self.cnn1 = nn.Conv2d(2, 10, kernel_size=5)
        self.cnn2 = nn.Conv2d(10, 20, kernel_size=2)
        self.linear1 = nn.Linear(20, 10)
        self.linear2 = nn.Linear(10, 2)
        self.pool1 = nn.AvgPool2d(2)
        # Defined but not applied in forward (dropout is currently disabled).
        self.drop = nn.Dropout2d()
        self.flat1 = nn.Flatten()

    def forward(self, x):
        '''
        Args:
            x: tensor of shape (batch, 2, H, W).

        Returns:
            Tensor of shape (batch, 2).
        '''
        # torch.ReLu does not exist; torch.relu is the functional ReLU.
        x = torch.relu(self.pool1(self.cnn1(x)))
        x = torch.relu(self.pool1(self.cnn2(x)))
        x = self.flat1(x)
        x = torch.relu(self.linear1(x))
        return self.linear2(x)
    
    
class CNNModel_Inv(torch.nn.Module):
    '''
    Inversion model. Two CNNs predict (Ksat, a, b) from `x`. The predictions
    are de-normalized with `scale_data` and appended to the fixed parameters
    to drive the swb water-balance model. The simulated soil moisture is then
    mapped linearly as a * theta + b.
    '''
    def __init__(self):
        super(CNNModel_Inv, self).__init__()
        self.CnnKs=ShallowCNN_Ks()
        self.Cnnab=ShallowCNN_ab()
        self.phymodel=swb()

    def forward(self, x, fc_data,fix_para,scale_data):
        '''
        Args:
            x: CNN input batch.
            fc_data: forcing passed to swb; axis 1 is the date axis.
            fix_para: fixed parameters [batch, n_fix], placed before the learned ones.
            scale_data: [batch, 3, 2]; [:, :, 1] multiplies and [:, :, 0] is added
                to the raw CNN outputs (the [mean, std] de-normalization).

        Returns:
            (outCNN_SMAP, gen):
                outCNN_SMAP: [batch, date], equal to a * theta + b.
                gen: de-normalized [batch, [Ksat, a, b]].
        '''
        gen_Ks = self.CnnKs(x)   # [batch, 1]
        gen_ab = self.Cnnab(x)   # [batch, 2]
        gen_ori = torch.cat((gen_Ks, gen_ab), dim=1)
        gen = gen_ori * scale_data[:, :, 1] + scale_data[:, :, 0]
        # swb reads Ksat from column 7 of para, i.e. right after the fixed parameters.
        para = torch.cat((fix_para, gen), dim=-1)
        outCNN_all = self.phymodel(fc_data, para)
        outCNN_th = outCNN_all[:, :, 0]
        # [batch, 1] columns broadcast over the date axis; this is equivalent
        # to tiling a and b nday times.
        a = gen[:, 1:2]
        b = gen[:, 2:3]
        outCNN_SMAP = outCNN_th * a + b
        return outCNN_SMAP, gen

# %% model training
def train_batch(model, x, y, fc_data, fix_para, obs_loc,scale_data, optimizer, loss_fn):
    """Run one optimization step on a single batch.

    Calls model.forward(x, fc_data, fix_para, scale_data), scores it with
    loss_fn(state_pred, param_pred, y, obs_loc), backpropagates and steps the
    optimizer.

    Returns:
        The batch loss as a Python float.
    """
    predictions = model.forward(x, fc_data, fix_para, scale_data)
    state_pred, param_pred = predictions
    batch_loss = loss_fn(state_pred, param_pred, y, obs_loc)

    # Clear gradients left over from the previous step, then backprop and update.
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

    return batch_loss.detach().item()


def train(model, loader, optimizer, loss_fn, epochs):
    """Train `model` for `epochs` passes over `loader`.

    If CUDA is available, the model, the loss module and every batch are moved
    to the GPU. Each batch gets one optimizer step via train_batch, and
    progress is printed.

    Returns:
        list of float: the loss of every batch, in training order.
    """
    losses = list()

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        loss_fn = loss_fn.cuda()
        model = model.cuda()
    batch_index = 0  # cumulative over all epochs
    for e in range(epochs):
        # Defined up front so the epoch summary works even if the loader is empty
        # (previously a NameError).
        loss = None
        for batch in loader:
            if use_cuda:
                batch = [t.cuda() for t in batch]
            x, y, fc_data, fix_para, obs_loc, scale_data = batch
            loss = train_batch(model=model, x=x, y=y, fc_data=fc_data, fix_para=fix_para, obs_loc=obs_loc, scale_data=scale_data,optimizer=optimizer, loss_fn=loss_fn)
            losses.append(loss)

            batch_index += 1
            end = time.time()
            # Elapsed time is measured from the module-level `begin` timestamp.
            print('batch_index:{},loss: {} use time {}'.format(batch_index,loss,end-begin))

        print("Epoch: ", e+1,"Batches: ", batch_index,"loss: ",loss)

    return losses

# %% model testing
def test_batch(model, x, y, fc_data, fix_para, obs_loc, scale_data):
    # Run forward calculation
    state_pred,param_pred = model.forward(x,fc_data, fix_para, scale_data)

    return state_pred,param_pred


def test(model, loader):
    """Run `model` on every batch of `loader` on the CPU and collect predictions.

    Returns:
        (state_pred, param_pred): NumPy arrays with the per-batch predictions
        concatenated along axis 0.
    """
    state_pred_vectors = list()
    param_pred_vectors = list()

    # nn.Module.cpu() moves the parameters in place, so the caller's model is
    # left on the CPU afterwards.
    model = model.cpu()
    # Inference only. Without no_grad, each forward through swb's per-date loop
    # would record an autograd graph for the entire evaluated dataset.
    with torch.no_grad():
        for x, y, fc_data, fix_para, obs_loc, scale_data in loader:
            state_pred, param_pred = test_batch(model=model, x=x, y=y, fc_data=fc_data, fix_para=fix_para, obs_loc=obs_loc, scale_data=scale_data)
            state_pred_vectors.append(state_pred.detach().cpu().numpy())
            param_pred_vectors.append(param_pred.detach().cpu().numpy())

    return np.concatenate(state_pred_vectors), np.concatenate(param_pred_vectors)
        
# %%  custom loss function
class RangeBoundLoss(torch.nn.Module):# crit
    """Penalty for parameters outside [lb, ub] (pjli 20220417).

    Returns factor * sum(relu((x - ub) / (ub - lb))) + factor * sum(relu((lb - x) / (ub - lb))),
    i.e. the distance outside the range measured in units of the range width,
    summed over all elements of x. Extra positional arguments are ignored.
    """
    def __init__(self,lb,ub,factor=1):
        super(RangeBoundLoss, self).__init__()
        # Buffers follow the module through .cuda()/.to(). The loss therefore
        # works on CPU-only machines and still moves to the GPU with its owner,
        # instead of hard-coding .cuda() here.
        self.register_buffer('lb', torch.Tensor(lb))
        self.register_buffer('ub', torch.Tensor(ub))
        self.register_buffer('factor', torch.tensor(factor))

    def forward(self, x, *args):
        # x: [batch, [Ksat, a, b]] parameter predictions (pjli 20220408)
        width = self.ub - self.lb
        above = torch.relu((x - self.ub) / width).sum()
        below = torch.relu((self.lb - x) / width).sum()
        return self.factor * above + self.factor * below

# 1. Inherit from nn.Module
class My_loss(nn.Module):
    """Sum over samples of the RMSE between observed predictions and targets.

    For sample k, the predictions at positions where obs_loc[k] == 1 are
    compared with y[k]. An empty batch gives 0.
    """

    def __init__(self):
        super().__init__()

    def forward(self, state_pred, y, obs_loc):
        def sample_rmse(k):
            observed = state_pred[k, obs_loc[k] == 1]
            return torch.sqrt(((observed - y[k]) ** 2).mean())

        return sum(sample_rmse(k) for k in range(state_pred.shape[0]))
        
class sumOfLoss(torch.nn.Module):
    """Add several loss modules together.

    The first loss is the state loss, called as func(state_pred, y, obs_loc).
    Every later loss is a parameter loss, called as func(para_pred, None).
    """
    def __init__(self,*args,factor=1):
        """
        Args:
            *args: loss modules, state loss first.
            factor: accepted for backward compatibility; not used.
        """
        super(sumOfLoss, self).__init__()
        # ModuleList registers the losses as children, so sumOfLoss.cuda()/.to()
        # (as called by train) moves them as well; no device is hard-coded.
        self.lossFuncs = nn.ModuleList(args)

    def forward(self, state_pred, para_pred, y,obs_loc):
        if not torch.is_tensor(y):
            print('sum of loss: torch.is_tensor(y) false')
            return 0
        # Evaluate on y's device (train moves batches to the GPU when one is
        # available) instead of the hard-coded 'cuda:0', which fails on CPU.
        state_pred = state_pred.to(y.device)
        para_pred = para_pred.to(y.device)
        loss = 0
        for i, func in enumerate(self.lossFuncs):
            if i == 0:  # state calibration (pjli 20220407)
                loss = loss + func(state_pred, y, obs_loc)
                print('state loss {}'.format(loss))
            else:  # parameter calibration; computed once, then added and reported
                para_loss = func(para_pred, None)
                loss = loss + para_loss
                print('para loss {}'.format(para_loss))
        return loss
    



# %% model running    
def run(dataset_train, dataset_test, batch_size_train=100, learning_rate=1e-3, epoch=300):
    """Train a fresh CNNModel_Inv, save it, and predict on the train and test sets.

    Args:
        dataset_train: Dataset used for training (shuffled mini-batches) and for
            the in-sample "fit" predictions (one batch holding the whole set).
        dataset_test: Dataset used for out-of-sample predictions (one batch).
        batch_size_train: number of samples per training mini-batch.
        learning_rate: Adam learning rate.
        epoch: number of training epochs.

    Returns:
        (losses, y_fit, para_fit, y_predict, para_predict): the per-batch
        training losses, then the state/parameter predictions on the training
        set and on the test set.
    """
    full_train_size = len(dataset_train)
    full_test_size = len(dataset_test)
    train_loader = DataLoader(dataset=dataset_train, batch_size=batch_size_train, shuffle=True)
    fit_loader = DataLoader(dataset=dataset_train, batch_size=full_train_size, shuffle=False)
    test_loader = DataLoader(dataset=dataset_test, batch_size=full_test_size, shuffle=False)

    model = CNNModel_Inv()
    adam = optim.Adam(model.parameters(), lr=learning_rate)

    # Total loss = observation misfit + penalty for parameters outside their range.
    # Column 0 of data_para_range is used as the upper bound, column 1 as the lower.
    state_loss = My_loss()
    upper, lower = data_para_range[:, 0], data_para_range[:, 1]
    range_loss = RangeBoundLoss(lb=lower, ub=upper, factor=1)
    criterion = sumOfLoss(state_loss, range_loss)

    losses = train(model=model, loader=train_loader, optimizer=adam, loss_fn=criterion, epochs=epoch)
    torch.save(model, path_data + 'model.pth')

    y_fit, para_fit = test(model=model, loader=fit_loader)
    y_predict, para_predict = test(model=model, loader=test_loader)
    return losses, y_fit, para_fit, y_predict, para_predict

# %% model running
batch_size_train = 512
epoch = 5
# Batches per epoch. The training DataLoader keeps the final partial batch
# (drop_last defaults to False), so this is a ceiling division. The previous
# floor undercounted whenever the last batch was partial, and was 0 for
# training sets smaller than one batch, which broke the final-loss average.
n_batch = -(-x_train.shape[0] // batch_size_train)
losses, y_fit, para_fit, y_pred, para_pred = run(dataset_train=dataset_train, dataset_test=dataset_test,
                                                 batch_size_train=batch_size_train,
                                                 epoch=epoch)

# De-normalization: map z-scored values back to their original scale
def fun_reback(data, mean, std):
    """Invert a z-score normalization, returning data * std + mean."""
    scaled = data * std
    return scaled + mean
# Undo the input normalization; targets and predictions were never normalized.
x_train_reback = fun_reback(x_train, x_mean, x_std)
x_test_reback = fun_reback(x_test, x_mean, x_std)
y_train_reback, y_test_reback = y_train, y_test
y_fit_reback, y_pred_reback = y_fit, y_pred

# %% plot loss values
def plot_loss(losses, show=True):
    """Plot the per-iteration training loss curve.

    Args:
        losses: sequence of loss values, one per training iteration.
        show: if True, display the figure with plt.show() before closing it.
    """
    # Use a dedicated figure and axes. plt.gcf() + plt.axes() drew onto
    # whatever figure was current and stacked a new Axes over existing ones.
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.plot(range(len(losses)), losses)

    if show:
        plt.show()

    plt.close(fig)

plot_loss(losses)

# Average loss over the last epoch's batches. If n_batch is not positive, fall
# back to all losses (previously losses[-0:] / 0 raised ZeroDivisionError).
recent_losses = losses[-n_batch:] if n_batch > 0 else losses
if recent_losses:
    print("Final loss:", sum(recent_losses) / len(recent_losses))

np.savez(path_data+'data02_result.npz',
         x_train=x_train_reback,
         x_test=x_test_reback,
         y_train=y_train_reback,
         y_test=y_test_reback,
         y_fit=y_fit_reback,
         para_fit = para_fit,
         y_pred=y_pred_reback,
         para_pred=para_pred,
         loss=losses)
